# coding=utf-8
# Copyright 2024 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Stanford dogs dataset."""

import collections
import os
import re
import xml.etree.ElementTree as ET

from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
import tensorflow_datasets.public_api as tfds

# Dataset homepage, reported in the dataset info.
_URL = "http://vision.stanford.edu/aditya86/ImageNetDogs/main.html"

# Archives holding the images, the train/test split lists (.mat files) and
# the per-image XML annotations.
_IMAGES_URL = "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar"
_SPLIT_URL = "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"
_ANNOTATIONS_URL = (
    "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar"
)
# Matches an archive member path that ends in "<name>.jpg".
#   group(1): the LAST directory component including its trailing separator
#             (a repeated group keeps only its final repetition), or None
#             when the path has no directory part.
#   group(2): the image file name, e.g. "n02085620_10074.jpg".
# The dot is escaped so that only a literal ".jpg" extension is accepted.
_NAME_RE = re.compile(r"([\w-]*[/\\])*(\w*\.jpg)$")


class Builder(tfds.core.GeneratorBasedBuilder):
  """Stanford Dogs dataset.

  Builds a train and a test split. Each example carries the image, its path
  inside the images archive, a 120-way class label and the bounding boxes
  read from the matching XML annotation.
  """

  # Version 0.2.0: Fix non-deterministic label names
  VERSION = tfds.core.Version("0.2.0")

  def _info(self):
    """Returns the dataset info: features, supervised keys and homepage."""
    return self.dataset_info_from_configs(
        features=tfds.features.FeaturesDict({
            # Images are of varying size
            "image": tfds.features.Image(),
            "image/filename": tfds.features.Text(),
            "label": tfds.features.ClassLabel(num_classes=120),
            # Multiple bounding box per image
            "objects": tfds.features.Sequence(
                {
                    "bbox": tfds.features.BBoxFeature(),
                }
            ),
        }),
        supervised_keys=("image", "label"),
        homepage=_URL,
    )

  def _split_generators(self, dl_manager):
    """Downloads the data and defines the train and test splits.

    The images archive is only downloaded (it is streamed later with
    `iter_archive`); the split lists and the annotations are downloaded and
    extracted. Label names are derived from the directories listed in the
    train list and stored sorted on the `label` feature.

    Args:
      dl_manager: download manager used to fetch and extract the archives.

    Returns:
      A list with the TRAIN and TEST `SplitGenerator`s.

    Raises:
      ValueError: if the extracted split archive has no file whose name
        contains "train" or no file whose name contains "test".
    """
    images_path = dl_manager.download(_IMAGES_URL)
    split_path, annotation_path = dl_manager.download_and_extract(
        [_SPLIT_URL, _ANNOTATIONS_URL]
    )

    # Parsing the mat file which contains the list of train/test images
    scipy = tfds.core.lazy_imports.scipy

    def parse_mat_file(file_name):
      """Returns (image base names, parsed mat content) for a split file."""
      with tf.io.gfile.GFile(file_name, "rb") as f:
        parsed_mat_arr = scipy.io.loadmat(f, squeeze_me=True)
      file_list = [
          os.path.split(element)[-1] for element in parsed_mat_arr["file_list"]
      ]

      return file_list, parsed_mat_arr

    train_list = None
    test_list = None
    label_names = None

    for fname in tf.io.gfile.listdir(split_path):
      # Train-test split using train_list.mat and test_list.mat
      full_file_name = os.path.join(split_path, fname)

      if "train" in fname:
        train_list, train_mat_arr = parse_mat_file(full_file_name)
        # Entries look like path/label/img.jpg; the directory part is the
        # label. A set removes the duplicates.
        label_names = {
            os.path.split(element)[-2].lower()
            for element in train_mat_arr["file_list"]
        }
      elif "test" in fname:
        test_list, _ = parse_mat_file(full_file_name)

    if train_list is None or test_list is None:
      raise ValueError(
          "Could not find both a train and a test list file in "
          f"{split_path!r}."
      )

    # Sorted so that the label <-> index mapping is deterministic.
    self.info.features["label"].names = sorted(label_names)

    # Parsing the XML files which hold the image annotations, keyed by the
    # annotation file name.
    xml_file_list = collections.defaultdict(str)
    for root, _, files in tf.io.gfile.walk(annotation_path):
      for fname in files:
        annotation_file_name = os.path.join(root, fname)
        with tf.io.gfile.GFile(annotation_file_name, "rb") as f:
          xml_file_list[fname] = ET.parse(f)

    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={
                "archive": dl_manager.iter_archive(images_path),
                "file_names": train_list,
                "annotation_files": xml_file_list,
            },
        ),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            gen_kwargs={
                "archive": dl_manager.iter_archive(images_path),
                "file_names": test_list,
                "annotation_files": xml_file_list,
            },
        ),
    ]

  def _generate_examples(self, archive, file_names, annotation_files):
    """Generates dog images, labels and bounding boxes for one split.

    Args:
      archive: iterable of `(path, file object)` pairs over the members of
        the images archive.
      file_names: list of train/test image file names (base names) obtained
        from the mat file.
      annotation_files: dict mapping an annotation file name (the image file
        name up to its first ".") to its parsed XML object.

    Yields:
      The image path inside the archive as key, and a dict with the image,
      its path, its label and its bounding boxes.
    """
    bbox_attrib = frozenset(
        ["xmin", "xmax", "ymin", "ymax", "width", "height"]
    )
    # Build the lookup once: membership is tested for every archive member.
    wanted_names = set(file_names)

    # BBox attributes in range of 0.0 to 1.0
    def normalize_bbox(bbox_side, image_side):
      return min(bbox_side / image_side, 1.0)

    def build_box(attributes, n):
      # The n-th box is scaled by the first width/height found in the XML.
      return tfds.features.BBox(
          ymin=normalize_bbox(attributes["ymin"][n], attributes["height"][0]),
          xmin=normalize_bbox(attributes["xmin"][n], attributes["width"][0]),
          ymax=normalize_bbox(attributes["ymax"][n], attributes["height"][0]),
          xmax=normalize_bbox(attributes["xmax"][n], attributes["width"][0]),
      )

    for fname, fobj in archive:
      res = _NAME_RE.match(fname)
      # Skip members that are not images of the requested split.
      if not res or os.path.split(fname)[-1] not in wanted_names:
        continue

      # group(1) is the parent directory plus its separator; dropping the
      # separator leaves the label. NOTE(review): assumes every selected
      # image sits inside a label directory -- confirm against the archive.
      label = res.group(1)[:-1].lower()
      file_name = res.group(2)
      attributes = collections.defaultdict(list)
      for element in annotation_files[file_name.split(".")[0]].iter():
        # Extract necessary Bbox attributes from XML file
        tag = element.tag.strip()
        if tag in bbox_attrib:
          attributes[tag].append(float(element.text.strip()))

      yield fname, {
          "image": fobj,
          "image/filename": fname,
          "label": label,
          "objects": [
              {"bbox": build_box(attributes, n)}
              for n in range(len(attributes["xmin"]))
          ],
      }
